{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# self-attention\n",
    "\n",
    "这里仅仅尝试复现论文中的模型，从最简单的 attention开始，到之后的 multi-head attention。Block 的构建，暂时不写了。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. scaled-dot-product-attention\n",
    "\n",
    "![scaled-dot-product-attention](./img/scaled-dot-product-attention.png)\n",
    "\n",
    "$$Attention(\\boldsymbol{Q},\\boldsymbol{K},\\boldsymbol{V}) = softmax\\left(\\frac{\\boldsymbol{Q}\\boldsymbol{K}^{\\top}}{\\sqrt{d_k}}\\right)\\boldsymbol{V}$$\n",
    "\n",
    "通常情况下，$Q, K, V$ 的维数会相等"
   ]
  },
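  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The $\\sqrt{d_k}$ factor can be motivated with a short variance argument. This is a sketch under an assumption that the notebook does not state: the components $q_i$ and $k_i$ of a query and a key are independent with mean 0 and variance 1.\n",
    "\n",
    "$$\\mathrm{Var}(\\boldsymbol{q}\\cdot\\boldsymbol{k}) = \\mathrm{Var}\\left(\\sum_{i=1}^{d_k} q_i k_i\\right) = \\sum_{i=1}^{d_k} \\mathrm{Var}(q_i k_i) = d_k$$\n",
    "\n",
    "Dividing the scores by $\\sqrt{d_k}$ brings their variance back to 1, which helps keep the softmax out of its saturated region as $d_k$ grows."
   ]
  },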
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.autograd import *\n",
    "from torch.nn.parameter import Parameter\n",
    "import numpy as np\n",
    "\n",
    "# 这里不做 mask，同时，我们默认 dk = dv\n",
    "class ScaledDotProductAttention(nn.Module):\n",
    "\n",
    "    def __init__(self, d_model):\n",
    "        '''scaled-dot-product-attention\n",
    "            parameters: \n",
    "                d_model: A scalar. attention size\n",
    "        '''\n",
    "        super(ScaledDotProductAttention, self).__init__()\n",
    "        self.temper = np.power(d_model, 0.5)\n",
    "    \n",
    "    def forward(self, Q, K, V):\n",
    "        ''' forward step\n",
    "            parameters: \n",
    "                Q (batch*n*dk)\n",
    "                K (batch*m*dk)\n",
    "                V (batch*m*dv)\n",
    "            note: dv == dk\n",
    "        '''\n",
    "        qk = torch.bmm(Q, K.transpose(1, 2)) # (batch*n*dk) x (batch*dk*m) -> batch*n*m\n",
    "        weight = F.softmax(qk / self.temper, dim=1) # batch*n*m -> batch*n*m\n",
    "        attention_V = torch.matmul(weight, V) # (batch*n*m) x (batch*m*dv) -> batch*n*dv\n",
    "        return attention_V"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([16, 10, 32])\n"
     ]
    }
   ],
   "source": [
    "# test \n",
    "# batch = 16, n = 10, m = 20, dk = dv = 32\n",
    "# input: Q (16*10*32), K(16*20*32), V(16*20*32)\n",
    "# output: attention_V (16*10*32)\n",
    "\n",
    "sdpt_model = ScaledDotProductAttention(d_model=32)\n",
    "\n",
    "Q1 = Variable(torch.randn(16, 10, 32))\n",
    "K1 = Variable(torch.randn(16, 20, 32))\n",
    "V1 = Variable(torch.randn(16, 20, 32))\n",
    "\n",
    "attention_V = sdpt_model(Q1, K1, V1)\n",
    "print(attention_V.size())"
   ]
  },
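  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a sanity check on the formula in section 1, here is a minimal functional sketch. The function name `sdpa` is hypothetical and not part of the classes in this notebook. It normalizes the softmax over the key axis (the last one), so every query's weights should sum to 1, which the final line checks."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "def sdpa(Q, K, V):\n",
    "    # hypothetical helper: Q (batch, n, dk), K (batch, m, dk), V (batch, m, dv)\n",
    "    scores = torch.bmm(Q, K.transpose(1, 2)) / math.sqrt(Q.size(-1))  # (batch, n, m)\n",
    "    weights = F.softmax(scores, dim=-1)  # normalize over the m keys\n",
    "    return torch.bmm(weights, V), weights  # (batch, n, dv), (batch, n, m)\n",
    "\n",
    "torch.manual_seed(0)\n",
    "out, w = sdpa(torch.randn(2, 3, 8), torch.randn(2, 5, 8), torch.randn(2, 5, 8))\n",
    "print(out.shape, w.shape)\n",
    "# each row of weights is a distribution over the keys\n",
    "print(torch.allclose(w.sum(dim=-1), torch.ones(2, 3)))"
   ]
  },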
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Multi-Head Attention\n",
    "\n",
    "![multi-head-attention](./img/multi-head-attention.png)\n",
    "\n",
    "\n",
    "\n",
    "先对 $Q, K, V$ 进行矩阵映射，然后再扔进前面的 scaled-dot-product-attention network中，最后把这个过程重复做 $h$ 次，就是 multi-head-attention network\n",
    "\n",
    "$$head_i = Attention(\\boldsymbol{Q}\\boldsymbol{W}_i^Q,\\boldsymbol{K}\\boldsymbol{W}_i^K,\\boldsymbol{V}\\boldsymbol{W}_i^V)$$\n",
    "\n",
    "其中 $\\boldsymbol{W}_i^Q\\in\\mathbb{R}^{d_k\\times \\tilde{d}_k}, \\boldsymbol{W}_i^K\\in\\mathbb{R}^{d_k\\times \\tilde{d}_k}, \\boldsymbol{W}_i^V\\in\\mathbb{R}^{d_v\\times \\tilde{d}_v}$\n",
    "\n",
    "$$MultiHead(\\boldsymbol{Q},\\boldsymbol{K},\\boldsymbol{V}) = Concat(head_1,...,head_h)$$"
   ]
  },
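  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The two formulas above can be written as an explicit loop over heads. This is only a sketch with random, untrained projection matrices. The names `multi_head`, `W_q`, `W_k`, and `W_v` are hypothetical, and the inputs are assumed to share one feature size `d_model`, as in the tests in this notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "def multi_head(Q, K, V, W_q, W_k, W_v):\n",
    "    # hypothetical helper: W_q, W_k are (h, d_model, dk_hat), W_v is (h, d_model, dv_hat)\n",
    "    heads = []\n",
    "    for Wq_i, Wk_i, Wv_i in zip(W_q, W_k, W_v):\n",
    "        q, k, v = Q @ Wq_i, K @ Wk_i, V @ Wv_i  # per-head projections\n",
    "        scores = q @ k.transpose(1, 2) / math.sqrt(q.size(-1))  # (batch, n, m)\n",
    "        heads.append(F.softmax(scores, dim=-1) @ v)  # head_i: (batch, n, dv_hat)\n",
    "    return torch.cat(heads, dim=-1)  # Concat(head_1, ..., head_h): (batch, n, h*dv_hat)\n",
    "\n",
    "torch.manual_seed(0)\n",
    "h, d_model, d_hat = 8, 32, 4\n",
    "W_q, W_k, W_v = (torch.randn(h, d_model, d_hat) for _ in range(3))\n",
    "Y = multi_head(torch.randn(16, 10, d_model), torch.randn(16, 20, d_model),\n",
    "               torch.randn(16, 20, d_model), W_q, W_k, W_v)\n",
    "print(Y.shape)"
   ]
  },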
  {
   "cell_type": "code",
   "execution_count": 68,
   "metadata": {},
   "outputs": [],
   "source": [
    "# \n",
    "# code reference from https://github.com/jadore801120/attention-is-all-you-need-pytorch\n",
    "# \n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.autograd import *\n",
    "from torch.nn.parameter import Parameter\n",
    "import numpy as np\n",
    "import math\n",
    "\n",
    "# 使用残差进行链接, no mask\n",
    "class MultiHeadAttention(nn.Module):\n",
    "\n",
    "    def __init__(self, d_model, d_k_hat, d_v_hat, n_head=8, dropout_rate=0, mask=False):\n",
    "        '''multi-head-attention.\n",
    "            parameters:\n",
    "                d_model: A scalar. attention size.\n",
    "                d_k_hat: A scalar. linear project dimension of k.\n",
    "                d_v_hat: A scalar. linear project dimension of v.\n",
    "                num_heads: An int. Number of heads.\n",
    "                dropout_rate: A floating point number. drop_ou\n",
    "        '''\n",
    "        super(MultiHeadAttention, self).__init__()\n",
    "        \n",
    "        self.n_head = n_head\n",
    "        self.d_k_hat = d_k_hat # 通常 d_k_hat = d_model / n_head\n",
    "        self.d_v_hat = d_v_hat # 通常 d_v_hat = d_model / n_head\n",
    "        \n",
    "        self.w_qs = nn.Parameter(torch.FloatTensor(n_head, d_model, d_k_hat))\n",
    "        self.w_ks = nn.Parameter(torch.FloatTensor(n_head, d_model, d_k_hat))\n",
    "        self.w_vs = nn.Parameter(torch.FloatTensor(n_head, d_model, d_v_hat))\n",
    "        \n",
    "        self.attention_net = ScaledDotProductAttention(d_model)\n",
    "        \n",
    "        self.linear_proj = torch.nn.Linear(n_head*d_v_hat, d_model)\n",
    "        \n",
    "        self.dropout = nn.Dropout(dropout_rate)\n",
    "        \n",
    "        self.mask = mask\n",
    "\n",
    "    def forward(self, Q, K, V):\n",
    "        ''' forward step\n",
    "            parameters: Q (batch*n*d_model), K(batch*m*d_model), V(batch*m*d_model)\n",
    "        '''\n",
    "        d_k_hat, d_v_hat = self.d_k_hat, self.d_v_hat\n",
    "        \n",
    "        residual = Q # batch_size x len_q x d_model\n",
    "        \n",
    "        n_head = self.n_head\n",
    "        \n",
    "        batch_size, len_q, d_model = Q.size()\n",
    "        batch_size, len_k, d_model = K.size()\n",
    "        batch_size, len_v, d_model = V.size()\n",
    "        \n",
    "        # 重复 multi-head 次，方便之后进行线性变换\n",
    "        q_s = Q.repeat(n_head, 1, 1).view(n_head, -1, d_model) # n_head*(batch_size*len_q)*d_model\n",
    "        k_s = K.repeat(n_head, 1, 1).view(n_head, -1, d_model) # n_head*(mb_size*len_k)*d_model\n",
    "        v_s = V.repeat(n_head, 1, 1).view(n_head, -1, d_model) # n_head*(mb_size*len_v)*d_model\n",
    "        \n",
    "        # 线性变换\n",
    "        # bmm: (n_head*(batch_size*len_q)*d_model) x (n_head*d_model*d_k_hat) -> n_head*(batch_size*len_q)*d_k_hat\n",
    "        # view: n_head*(batch_size*len_q)*d_k_hat -> (n_head*batch_size)*len_q*d_k_hat\n",
    "        q_s = torch.bmm(q_s, self.w_qs).view(-1, len_q, d_k_hat) \n",
    "        k_s = torch.bmm(k_s, self.w_ks).view(-1, len_k, d_k_hat)\n",
    "        v_s = torch.bmm(v_s, self.w_vs).view(-1, len_v, d_v_hat)\n",
    "        \n",
    "        # 扔进 Attention network 中\n",
    "        outputs = self.attention_net(q_s, k_s, v_s) # (n_head*batch_size)*len_q*d_v_hat\n",
    "        \n",
    "        # concatenate 操作，复原到  batch_size x len_q x (n_head*d_v_hat)\n",
    "        # split: (n_head*batch_size)*len_q*d_v_hat ->  n_head 个 [batch_size*len_q*d_v_hat]\n",
    "        # cat: n_head 个 [batch_size*len_q*d_v_hat] -> batch_size x len_q x (n_head*d_v_hat)\n",
    "        outputs = torch.cat(torch.split(outputs, batch_size, dim=0), dim=-1)\n",
    "        \n",
    "        # 最后一个 linear layer\n",
    "        outputs = self.linear_proj(outputs) # batch_size x len_q x (n_head*d_v_hat) -> batch_size x len_q x d_model\n",
    "        outputs = self.dropout(outputs)\n",
    "        \n",
    "        # 残差\n",
    "        return outputs + residual"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 69,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([16, 10, 32])\n"
     ]
    }
   ],
   "source": [
    "# test \n",
    "# batch = 16, n = 10, m = 20, dk = dv = 32\n",
    "# input: Q (16*10*32), K(16*20*32), V(16*20*32)\n",
    "# output: attention_V (16*10*32)\n",
    "\n",
    "mha_model = MultiHeadAttention(d_model=32, d_k_hat=4, d_v_hat=4, n_head=8, dropout_rate=0)\n",
    "\n",
    "Q2 = Variable(torch.randn(16, 10, 32))\n",
    "K2 = Variable(torch.randn(16, 20, 32))\n",
    "V2 = Variable(torch.randn(16, 20, 32))\n",
    "\n",
    "attention_V2 = mha_model(Q2, K2, V2)\n",
    "print(attention_V2.size())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Self-Attention\n",
    "\n",
    "我们让 $Q = K = V$ 扔进 Multi-Head-Attention network 中得到的就是 self-attention\n",
    "\n",
    "$$\\boldsymbol{Y}=MultiHead(\\boldsymbol{X},\\boldsymbol{X},\\boldsymbol{X})$$"
   ]
  },
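  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For comparison, PyTorch ships a built-in `torch.nn.MultiheadAttention` layer, and passing the same tensor as query, key, and value gives self-attention. This is a sketch rather than the class defined in section 2: the built-in layer applies an output projection but no residual connection, and `batch_first=True` assumes a reasonably recent PyTorch release."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "torch.manual_seed(0)\n",
    "X = torch.randn(16, 10, 32)  # (batch, seq_len, d_model)\n",
    "self_attn = nn.MultiheadAttention(embed_dim=32, num_heads=8, batch_first=True)\n",
    "\n",
    "# Q = K = V = X; the second return value holds the attention weights\n",
    "Y, attn_weights = self_attn(X, X, X)\n",
    "print(Y.shape, attn_weights.shape)"
   ]
  },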
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 总结\n",
    "\n",
    "其实还有几个 layer_norm、 Mask 没有去实现， 不过这里并没有去实现，所以，等以后需要用到了，再去整吧。"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
